[Perf] Support Flashinfer RoPE+Quant+KV update kernel for trtllm_mha backend for GPT-OSS#15729
[Perf] Support Flashinfer RoPE+Quant+KV update kernel for trtllm_mha backend for GPT-OSS#15729elvischenv wants to merge 10 commits intosgl-project:mainfrom
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
5e5c50f to
7cc00cb
Compare
7cc00cb to
191dcf2
Compare
191dcf2 to
e59267f
Compare
|
This can be reviewed together with #19451 . They are very similar, except that one is for trtllm_mha and one is for trtllm_mla |
|
For the accuracy results, which model are you testing on? |
| return None | ||
|
|
||
| def support_rope_fusion(self) -> bool: | ||
| """Check if the current backend supports RoPE fusion.""" |
There was a problem hiding this comment.
Instead of adding this method in base class, can we control this fusion with an environ flag?
Now it is set to False by default. After this feature stabilizes it can be turned on by default
e59267f to
dc6a0ac
Compare
elvischenv
left a comment
There was a problem hiding this comment.
@Fridge003 Updated the testing results in the PR description. This PR currently depends on a Flashinfer PR flashinfer-ai/flashinfer#2792 to fix the compatibility issue with piecewise cudagraph.
Motivation
This PR is to support Flashinfer
rope_quantize_fp8_append_paged_kv_cachekernel for trtllm_mha backend and enable it on GPT-OSS.Depends on a Flashinfer PR to fix the piecewise cudagraph compatibility issue: flashinfer-ai/flashinfer#2792
Tested cmd
server:
server with eagle:
client(accuracy):
client(benchmark TP8 conc8):
Accuracy Results
PR
PR with eagle
main
main with eagle
Perf (GPT-OSS-120b TP8 con8)
PR: about 7% perf gain
main
Eagle Accept length
PR:
main:
Modifications
trtllm_mha_backend.py: support corerope_quantize_fp8_append_paged_kv_cachekernelgpt_oss.py: defer RoPE into attention backendradix_attention.py: defer RoPE into attention backendenviron.py: addSGLANG_ENABLE_FLASHINFER_ROPE_FUSION, by default disabledtest_trtllm_mha_backend.py: test trtllm mha backend, including basic and rope fusion functionalitytest_gpt_oss_models_rope_fusion.py: test gpt-oss e2e accuracy with rope fusion enabledChecklist